rm(list=ls())
library(dplyr)
library(FactorHet)
library(glue)
library(combinat)

# Read the packaged data object from disk.
# NOTE(review): `repdata_nonhisp` is not referenced again in the visible part
# of this script -- confirm whether this load is still needed.
repdata_nonhisp <- readRDS('code/packaged_data.RDS')

# Average pairwise mean absolute error between column-reordered matrices.
#
# For every pair (i, j), with i running over `A` and j running over `B`, the
# columns of A[[i]] are reordered by row i of `perm`, the columns of B[[j]] are
# reordered by row j of `perm`, and the mean absolute element-wise difference
# between the two reordered matrices is computed. The function returns the mean
# of those pairwise values.
#
# Arguments:
#   A, B  Lists of numeric matrices. In this script both are the same list.
#   perm  Matrix of column indices; row r is the column order applied to
#         A[[r]] and to B[[r]].
#
# Returns:
#   A single numeric value (NaN when `A` or `B` is empty).
f_perm <- function(A, B, perm) {
  pairwise <- vapply(seq_along(A), FUN = function(i) {
    # The reordered A[[i]] does not depend on j, so compute it once.
    A_i <- A[[i]][, perm[i, ]]
    # Iterate over B itself (not over the length of A) so that every element
    # of B is compared, whatever the length of A.
    vapply(seq_along(B), FUN = function(j) {
      mean(abs(A_i - B[[j]][, perm[j, ]]))
    }, FUN.VALUE = numeric(1))
  }, FUN.VALUE = numeric(length(B)))
  mean(pairwise)
}

# Containers filled in below and by the later per-k processing.
store_all <- data.frame()
save_output <- list()
out_AME <- list()
out_mod <- list()

# Every entry in 'AOAS_repeat/' is read as one saved RDS object.
path_list <- dir('AOAS_repeat/')

for (path in path_list) {
  print(path)
  all_data <- readRDS(paste0('AOAS_repeat/', path))
  # From each element of the loaded object keep the refit AME data ...
  out_AME[[path]] <- lapply(all_data, function(fit) fit$AME_refit$data)
  # ... and the moderator table, both stored under the file name.
  out_mod[[path]] <- lapply(all_data, function(fit) fit$moderator)
  gc()
}

# Align group labels across runs for k = 2 and k = 3, then save the results.
#
# For each k, the loop looks for a per-run reordering of the group columns that
# minimises f_perm(), i.e. the average pairwise mean absolute difference between
# the runs' wide marginal-effect matrices (presumably needed because group
# labels are not comparable across runs -- confirm):
#   1. random search: ITER iterations that re-draw the column order of a random
#      number of runs, followed by 5000 iterations that re-draw a single run;
#   2. exhaustive clean-up: for each run in turn, try every ordering of 1:k.
# A candidate is kept only when it strictly lowers the objective.
#
# The chosen orderings are then used to relabel `group` in the AME and
# moderator tables; relabelled and original versions are stored together in
# save_output[[k]].
#
# NOTE(review): no RNG seed is set in the visible code, so the search is not
# reproducible between executions.
for (k in c(2,3)){
  print(k)
  # Element k of each run's stored AME list (presumably the k-group fit).
  all_AME <- lapply(out_AME, FUN=function(i){return(i[[k]])})
  # Make into a "wide" format: one row per factor/level, one column per group
  # (the two leading factor/level columns are dropped).
  wide_AME <- lapply(all_AME, FUN=function(i){
    i %>%
      reshape2::dcast(factor + level ~ group, value.var = 'marginal_effect') %>%
      .[, -1:-2] %>%
      as.matrix
  })
  # Generate initial permutation matrix: one row per run, each row is 1:k
  perm_matrix <- matrix(1:k, nrow = k, ncol = length(wide_AME))
  perm_matrix <- t(perm_matrix)
  # Save original permutation matrix and evaluate error
  orig_perm <- perm_matrix
  objective <- f_perm(wide_AME, wide_AME, perm_matrix)
  # Set number of iterations to do permutations for
  ITER <- 10^4
  # Do this + 5000
  obj_track <- rep(NA, ITER + 5000)
  for (it in 1:(ITER + 5000)){
    if (it %% 1000 == 0){
      print(it)
      # tryCatch(plot(obj_track), error = function(e){NULL})
      # Progress report: distribution of the step-to-step objective changes
      print(summary(diff(obj_track[!is.na(obj_track)])))
    }
    # For the first ITER, draw some random number to permute and permute
    # For the last 5,000 only permute one at a time
    if (it > ITER){
      nflip <- 1
    }else{
      nflip <- sample(1:(nrow(perm_matrix)), 1)
    }
    row <- sample(nrow(perm_matrix), nflip)
    new_perm <- perm_matrix
    # Create the new permutation matrix
    for (j in row){
      new_perm[j,] <- sample(1:ncol(perm_matrix))
    }
    # Evaluate the error given this new permutation
    new_obj <- f_perm(wide_AME, wide_AME, new_perm)
    # If improved, then save it
    if (new_obj < objective){
      perm_matrix <- new_perm
      objective <- new_obj
    }
    obj_track[it] <- objective
  }
  # Finally, loop over each row and check all permutations to see if you can
  # improve it one last time
  # Snapshot of the objective before the clean-up pass (not used below).
  objective_cleanup <- objective
  perm_all <- combinat::permn(1:k)
  for (row in seq_len(nrow(perm_matrix))){
    # Candidate orderings are visited in random order
    for (j in perm_all[sample(1:length(perm_all))]){
      new_perm <- perm_matrix
      new_perm[row,] <- j
      new_obj <- f_perm(wide_AME, wide_AME, new_perm)
      if (new_obj < objective){
        perm_matrix <- new_perm
        objective <- new_obj
        print(objective)
      }
    }
  }
  # Parse the data: long-format copies of the wide matrices, with columns in
  # the original order (unperm) and in the selected order (perm)
  data_unperm <- lapply(seq_along(wide_AME), FUN=function(i){
    mi <- data.frame(wide_AME[[i]][,orig_perm[i,]])
    names(mi) <- paste0('X', 1:ncol(mi))
    mi %>% mutate(id = 1:n())
  }) %>%
    bind_rows(.id = 'sim') %>%
    reshape2::melt(id.vars = c('sim', 'id'))

  data_perm <- lapply(seq_along(wide_AME), FUN=function(i){
    mi <- data.frame(wide_AME[[i]][,perm_matrix[i,]])
    names(mi) <- paste0('X', 1:ncol(mi))
    mi %>% mutate(id = 1:n())
  }) %>%
    bind_rows(.id = 'sim') %>%
    reshape2::melt(id.vars = c('sim', 'id'))

  output_data <- bind_rows(
    data_perm %>% mutate(type = 'perm'),
    data_unperm %>% mutate(type = 'unperm')
  )

  # Relabel the groups in the long AME tables. Column c of the reordered wide
  # matrix is original group perm_matrix[i, c], so an original group g gets the
  # new label match(g, perm_matrix[i, ]).
  all_AME <- lapply(all_AME, FUN=function(i){
    i$group <- as.numeric(as.character(i$group))
    return(i)
  })
  # NOTE(review): `sim` here comes from the names of all_AME, whereas the
  # relabelled tables below are built from unnamed lists and therefore get
  # positional ids ("1", "2", ...) -- confirm that downstream code does not
  # rely on `sim` matching across type == 'perm' and type == 'unperm'.
  orig_all_AME <- bind_rows(all_AME, .id = 'sim')
  out_AME_k <- lapply(seq_along(all_AME), FUN=function(i){
    out_i <- all_AME[[i]]
    out_i$group <- match(out_i$group, perm_matrix[i,])
    return(out_i)
  }) %>% bind_rows(.id = 'sim')

  # Same relabelling for the moderator tables. The melted column names are
  # split at the dot into a measure name and a numeric group.
  # separate() lives in tidyr, which this script never attaches, so it is
  # called with an explicit namespace.
  all_mod <- lapply(out_mod, FUN=function(i){
    reshape2::melt(i[[k]], id.vars = c('variable', 'type'), variable.name = 'measure') %>%
      tidyr::separate(measure, into = c('measure', 'group'), sep = '\\.') %>%
      mutate(group = as.numeric(group))
  })
  orig_all_mod <- bind_rows(all_mod, .id = 'sim')
  out_mod_k <- lapply(seq_along(all_mod), FUN=function(i){
    out_i <- all_mod[[i]]
    out_i$group <- match(out_i$group, perm_matrix[i,])
    return(out_i)
  }) %>% bind_rows(.id = 'sim')

  save_AME <- bind_rows(
    orig_all_AME %>% mutate(type = 'unperm'),
    out_AME_k %>% mutate(type = 'perm')
  )
  # Values of `var` (presumably a variance) that are negative but within
  # sqrt(machine epsilon) of zero are set to exactly zero.
  save_AME$var <- ifelse(save_AME$var < 0 & save_AME$var > -sqrt(.Machine$double.eps), 0, save_AME$var)

  save_mod <- bind_rows(
    orig_all_mod %>% mutate(type = 'unperm'),
    out_mod_k %>% mutate(type = 'perm')
  )

  output <- list(
    save_AME = save_AME,
    save_mod = save_mod,
    output_data = output_data, perm_matrix = perm_matrix,
    obj_track = obj_track
  )
  # Stored at position k, so save_output[[1]] stays NULL.
  save_output[[k]] <- output

  rm(output, all_AME, save_AME, out_AME_k, orig_all_AME, perm_matrix, out_mod_k, orig_all_mod); gc()
}

saveRDS(save_output, 'final_output/repeat_final_output.RDS')
